# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layer that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""
# pylint: disable=g-classes-have-attributes,g-direct-tensorflow-import

from keras.layers.attention.base_dense_attention import BaseDenseAttention
import tensorflow.compat.v2 as tf

from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Attention')
class Attention(BaseDenseAttention):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
     product: `scores = tf.matmul(query, key, transpose_b=True)`.
  2. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  3. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.
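
  For illustration, the same three steps written with raw TensorFlow ops
  (a minimal sketch; the shapes below are arbitrary):

  ```python
  import tensorflow as tf

  query = tf.random.normal([2, 3, 4])  # [batch_size, Tq, dim]
  value = tf.random.normal([2, 5, 4])  # [batch_size, Tv, dim]
  key = value  # Most common case: reuse `value` as `key`.

  scores = tf.matmul(query, key, transpose_b=True)  # [batch_size, Tq, Tv]
  distribution = tf.nn.softmax(scores)              # [batch_size, Tq, Tv]
  result = tf.matmul(distribution, value)           # [batch_size, Tq, dim]
  ```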

  Args:
    use_scale: If `True`, will create a scalar variable to scale the attention
      scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
      Defaults to `False`.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores. Defaults to 0.0.
    score_mode: Function to use to compute attention scores, one of
      `{"dot", "concat"}`. `"dot"` refers to the dot product between the query
      and key vectors. `"concat"` refers to Bahdanau-style additive scoring:
      the hyperbolic tangent of the (broadcast) sum of the query and key
      vectors, summed over the feature dimension.
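
  For example, a layer with a learned scale, additive (`"concat"`) scoring,
  and dropout applied to the attention distribution can be constructed as
  follows (a minimal sketch):

  ```python
  attention = tf.keras.layers.Attention(
      use_scale=True, score_mode='concat', dropout=0.1)
  ```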

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.
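
  For example, masking padded `value` positions and also retrieving the
  attention distribution (a minimal sketch; the shapes are arbitrary):

  ```python
  import tensorflow as tf

  query = tf.random.normal([2, 3, 8])              # [batch_size, Tq, dim]
  value = tf.random.normal([2, 5, 8])              # [batch_size, Tv, dim]
  # The second batch entry has only 2 valid value positions.
  value_mask = tf.sequence_mask([5, 2], maxlen=5)  # [batch_size, Tv]

  output, scores = tf.keras.layers.Attention()(
      [query, value],
      mask=[None, value_mask],
      return_attention_scores=True)
  # output.shape == (2, 3, 8); scores.shape == (2, 3, 5)
  ```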

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `Attention` in a CNN+Attention network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
  """

  def __init__(self, use_scale=False, score_mode='dot', **kwargs):
    super().__init__(**kwargs)
    self.use_scale = use_scale
    self.score_mode = score_mode
    if self.score_mode not in ['dot', 'concat']:
      raise ValueError(f'Received: score_mode={score_mode}. Acceptable values '
                       'are: ["dot", "concat"]')

  def build(self, input_shape):
    """Creates variable when `use_scale` is True or `score_mode` is `concat`."""
    if self.use_scale:
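      # Scalar multiplier applied to the raw attention scores.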
      self.scale = self.add_weight(
          name='scale',
          shape=(),
          initializer='ones',
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    if self.score_mode == 'concat':
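      # Scalar weight applied to the reduced tanh scores in `"concat"` mode.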
      self.concat_score_weight = self.add_weight(
          name='concat_score_weight',
          shape=(),
          initializer='ones',
          dtype=self.dtype,
          trainable=True)
    else:
      self.concat_score_weight = None
    super().build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.

    When `score_mode` is `'concat'`, scores are instead computed as a
    Bahdanau-style additive combination of the query and key vectors.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    if self.score_mode == 'dot':
      scores = tf.matmul(query, key, transpose_b=True)
      if self.scale is not None:
        scores *= self.scale
    elif self.score_mode == 'concat':
      # Reshape tensors to enable broadcasting.
      # Reshape into [batch_size, Tq, 1, dim].
      q_reshaped = tf.expand_dims(query, axis=-2)
      # Reshape into [batch_size, 1, Tv, dim].
      k_reshaped = tf.expand_dims(key, axis=-3)
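      # Broadcasting the sum gives shape [batch_size, Tq, Tv, dim]; reducing
      # over the last axis yields the [batch_size, Tq, Tv] score matrix.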
      if self.scale is not None:
        scores = self.concat_score_weight * tf.reduce_sum(
            tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1)
      else:
        scores = self.concat_score_weight * tf.reduce_sum(
            tf.tanh(q_reshaped + k_reshaped), axis=-1)

    return scores

  def get_config(self):
    config = {'use_scale': self.use_scale, 'score_mode': self.score_mode}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
